"""
Phase 1: Data Assembly — Safe Asset Cliff Panel
=================================================
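Merges safe_asset_panel and fiscal_panel for the rated countries returned by
build_ratings_panel (31 at the time of writing).
Constructs rating-event variables (downgrade_any, downgrade_notch, lost_safe,
lost_aaa, rating_category), fiscal stress variables and their one-year lags,
forward demographics (OADR 10 and 20 years ahead), OADR spline terms, and a
global OADR series.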

Output: safe_asset_cliff/data/processed/cliff_panel.csv
"""

import os
import sys
import warnings
from pathlib import Path

import numpy as np
import pandas as pd

warnings.filterwarnings('ignore')

# ── Paths ──────────────────────────────────────────────────────────────────
# PROJECT_DIR defaults to the original hard-coded location; set the
# SAFE_ASSET_CLIFF_DIR environment variable (non-empty) to run the script
# against another checkout. Sibling projects (safe_assets, fiscal_dominance,
# multilateral) are resolved relative to PROJECT_DIR's parent.
PROJECT_DIR = Path(
    os.environ.get("SAFE_ASSET_CLIFF_DIR")
    or "/mnt/c/demographics_capital_flows/safe_asset_cliff"
)
ROOT_DIR = PROJECT_DIR.parent
SAFE_DIR = ROOT_DIR / "safe_assets"
FISCAL_DIR = ROOT_DIR / "fiscal_dominance"
PROCESSED_DIR = PROJECT_DIR / "data" / "processed"
TABLES_DIR = PROJECT_DIR / "output" / "tables"

# Create output directories at import time (no-op if they already exist).
for d in [PROCESSED_DIR, TABLES_DIR]:
    d.mkdir(parents=True, exist_ok=True)

# ── Rating utilities from the sibling safe_assets project ──────────────────
# Prepend safe_assets/scripts to the module search path so its
# phase1_data_assembly module is found before any other module of that name.
sys.path.insert(0, str(SAFE_DIR / "scripts"))
from phase1_data_assembly import (  # noqa: E402  (requires the path entry above)
    RATING_HISTORY,
    RATING_SCALE,
    SAFE_THRESHOLD,
    build_ratings_panel,
)

# Inverse of RATING_SCALE (value -> key); used below to turn numeric ratings
# back into rating labels for the verification printout.
REVERSE_SCALE = dict(zip(RATING_SCALE.values(), RATING_SCALE.keys()))


def rating_category(rating_numeric):
    """Collapse a numeric rating into an ordinal band.

    Bands: 4 = AAA (21 and above), 3 = AA+ (20), 2 = AA (19), 1 = AA- (18),
    0 = anything below AA-. The input is truncated with ``int()`` before
    banding; a missing input (per ``pd.isna``) yields ``np.nan``.
    """
    if pd.isna(rating_numeric):
        return np.nan
    # Notches 18..21 map onto bands 1..4; clamp everything above 21 to 4
    # and everything below 18 to 0.
    return max(0, min(4, int(rating_numeric) - 17))


def _shift_by_year(df, col, periods):
    """Look up ``col`` for the same country ``periods`` calendar years away.

    ``periods > 0`` returns the value ``periods`` years earlier (a lag);
    ``periods < 0`` returns the value ``abs(periods)`` years later (a lead).
    Unlike ``df.groupby('iso3')[col].shift(periods)``, which steps by *row*,
    this matches on the calendar year, so a gap in a country's year sequence
    gives NaN instead of silently pairing non-adjacent years. If an
    (iso3, year) pair occurs more than once, its first occurrence supplies
    the value.

    Returns a NumPy array positionally aligned with the rows of ``df``.
    """
    src = (df[['iso3', 'year', col]]
           .dropna(subset=['iso3', 'year'])
           .drop_duplicates(subset=['iso3', 'year']))
    # Key each source value by the year at which it should appear.
    lookup = pd.Series(
        src[col].to_numpy(),
        index=pd.MultiIndex.from_arrays(
            [src['iso3'].to_numpy(), (src['year'] + periods).to_numpy()]),
    )
    target = pd.MultiIndex.from_arrays(
        [df['iso3'].to_numpy(), df['year'].to_numpy()])
    return lookup.reindex(target).to_numpy()


def _merge_panels(safe_df, fisc, ratings):
    """Step [4]: combine the safe-asset and fiscal panels.

    Starts from ``safe_df`` (ratings, macros, demographics) and left-joins
    only those fiscal variables that exist in ``fisc`` and are not already
    in ``safe_df``. Rating columns fall back to ``ratings`` when missing.
    """
    print("\n[4] Merging panels ...")
    fiscal_vars = ['govt_debt_gdp', 'govt_net_debt_gdp', 'primary_bal_gdp',
                   'structural_bal_gdp', 'govt_revenue_gdp', 'govt_expenditure_gdp',
                   'r_minus_g', 'r_minus_g_real', 'debt_lag', 'debt_change',
                   'oadr_plus20']
    new_vars = [c for c in fiscal_vars
                if c in fisc.columns and c not in safe_df.columns]
    if new_vars:
        fisc_merge = (fisc[['iso3', 'year'] + new_vars]
                      .drop_duplicates(subset=['iso3', 'year']))
        df = safe_df.merge(fisc_merge, on=['iso3', 'year'], how='left')
    else:
        df = safe_df.copy()

    # Ensure rating columns, taking them from the ratings panel if needed.
    if 'rating_numeric' not in df.columns or df['rating_numeric'].isna().all():
        df = df.drop(columns=['rating_numeric', 'safe_issuer'], errors='ignore')
        df = df.merge(ratings[['iso3', 'year', 'rating_numeric', 'safe_issuer']],
                      on=['iso3', 'year'], how='left')
    elif 'safe_issuer' not in df.columns:
        # Ratings present but flag absent: fetch only the flag, otherwise the
        # fillna below would raise KeyError.
        df = df.merge(ratings[['iso3', 'year', 'safe_issuer']],
                      on=['iso3', 'year'], how='left')
    # Country-years with no safe_issuer value are coded 0.
    df['safe_issuer'] = df['safe_issuer'].fillna(0).astype(int)

    print(f"  Merged panel: {df['iso3'].nunique()} countries, {len(df):,} obs")
    return df


def _add_event_variables(df):
    """Step [5]: rating lag, downgrade / threshold-crossing dummies, rating band."""
    print("\n[5] Constructing event variables ...")
    df = df.sort_values(['iso3', 'year'])

    # Rating in the previous calendar year (NaN if that year is not in the panel).
    df['rating_lag'] = _shift_by_year(df, 'rating_numeric', 1)

    # downgrade_any: rating decreased from prior year
    df['downgrade_any'] = ((df['rating_numeric'] < df['rating_lag']) &
                           df['rating_numeric'].notna() &
                           df['rating_lag'].notna()).astype(int)

    # downgrade_notch: size of the downgrade in notches (0 when no downgrade)
    df['downgrade_notch'] = np.where(
        df['downgrade_any'] == 1,
        df['rating_lag'] - df['rating_numeric'],
        0,
    )

    # lost_safe: moved from >= SAFE_THRESHOLD last year to below it this year
    df['lost_safe'] = ((df['rating_lag'] >= SAFE_THRESHOLD) &
                       (df['rating_numeric'] < SAFE_THRESHOLD) &
                       df['rating_numeric'].notna()).astype(int)

    # lost_aaa: fell below AAA (21)
    df['lost_aaa'] = ((df['rating_lag'] >= 21) &
                      (df['rating_numeric'] < 21) &
                      df['rating_numeric'].notna()).astype(int)

    # rating_category: ordinal band 0..4
    df['rating_category'] = df['rating_numeric'].apply(rating_category)

    print(f"  downgrade_any events: {df['downgrade_any'].sum()}")
    print(f"  lost_safe events: {df['lost_safe'].sum()}")
    print(f"  lost_aaa events: {df['lost_aaa'].sum()}")
    return df


def _add_fiscal_stress(df):
    """Step [6]: expenditure-revenue gap, 5-year debt change, primary gap.

    Each variable is created only when its input columns are present.
    """
    print("\n[6] Constructing fiscal stress variables ...")

    if 'govt_expenditure_gdp' in df.columns and 'govt_revenue_gdp' in df.columns:
        df['exp_rev_gap'] = df['govt_expenditure_gdp'] - df['govt_revenue_gdp']
        print(f"  exp_rev_gap: {df['exp_rev_gap'].notna().sum():,} non-null, "
              f"mean={df['exp_rev_gap'].mean():.2f}")

    # debt_change_5y: debt/GDP minus the same country's value 5 calendar years earlier
    if 'govt_debt_gdp' in df.columns:
        df['debt_lag5'] = _shift_by_year(df, 'govt_debt_gdp', 5)
        df['debt_change_5y'] = df['govt_debt_gdp'] - df['debt_lag5']
        print(f"  debt_change_5y: {df['debt_change_5y'].notna().sum():,} non-null")

    # primary_gap: rough deviation from a Bohn-style fiscal reaction rule,
    # approximated as primary_bal_gdp - 0.02 * (debt/GDP - 60).
    # NOTE(review): 0.02 and 60 are fixed assumptions, not estimated here.
    if 'primary_bal_gdp' in df.columns and 'govt_debt_gdp' in df.columns:
        df['primary_gap'] = df['primary_bal_gdp'] - 0.02 * (df['govt_debt_gdp'] - 60)
        print(f"  primary_gap: {df['primary_gap'].notna().sum():,} non-null")
    return df


def _add_forward_demographics(df):
    """Step [7]: same-country old_dep 10 and 20 calendar years ahead."""
    print("\n[7] Forward demographics ...")
    if 'old_dep' in df.columns:
        # Only years present in this panel can supply a lead, so the last
        # 10 / 20 years of each country's sample are NaN.
        df['oadr_plus10'] = _shift_by_year(df, 'old_dep', -10)
        df['oadr_plus20_fwd'] = _shift_by_year(df, 'old_dep', -20)
        print(f"  oadr_plus10: {df['oadr_plus10'].notna().sum():,} non-null")
        print(f"  oadr_plus20_fwd: {df['oadr_plus20_fwd'].notna().sum():,} non-null")

        # oadr_plus20 may already be present, e.g. merged from the fiscal panel.
        if 'oadr_plus20' in df.columns:
            print(f"  oadr_plus20 (from fiscal panel): {df['oadr_plus20'].notna().sum():,} non-null")
    return df


def _add_oadr_spline(df, knots=(15, 20, 25, 30)):
    """Step [8]: linear-spline terms ``max(0, old_dep - knot/100)`` per knot."""
    print("\n[8] OADR spline variables ...")
    if 'old_dep' in df.columns:
        # NOTE(review): knots are given in percent and divided by 100, which
        # assumes old_dep is stored as a fraction (0.15 == 15%) — confirm.
        names = []
        for knot in knots:
            name = f'oadr_spline_{knot}'
            df[name] = np.maximum(0, df['old_dep'] - knot / 100.0)
            names.append(name)
        print(f"  Created {', '.join(names)}")
    return df


def _add_fiscal_lags(df):
    """Step [9]: one-year lags (``<var>_lag``) of the fiscal stress variables."""
    print("\n[9] Lagged fiscal variables ...")
    for var in ['exp_rev_gap', 'govt_debt_gdp', 'primary_bal_gdp', 'primary_gap',
                'debt_change_5y']:
        if var in df.columns:
            df[f'{var}_lag'] = _shift_by_year(df, var, 1)
    return df


def _add_global_oadr(df):
    """Step [10]: attach a global OADR series when the panel lacks one.

    If ``global_oadr`` is already a column it is only reported. Otherwise,
    when the panel has ``old_dep`` and ``ngdp_usd`` and the multilateral
    full_panel.csv exists, it is computed per year as the ngdp_usd-weighted
    mean of old_dep over all countries in that file and merged on year.
    In every other case the panel is returned unchanged.
    """
    print("\n[10] Global OADR ...")
    if 'global_oadr' in df.columns:
        print(f"  global_oadr (from safe_asset_panel): {df['global_oadr'].notna().sum():,} non-null")
        return df
    if 'old_dep' not in df.columns or 'ngdp_usd' not in df.columns:
        return df

    full_panel_path = ROOT_DIR / "multilateral" / "followup" / "data" / "processed" / "full_panel.csv"
    if not full_panel_path.exists():
        return df

    fp = pd.read_csv(full_panel_path, usecols=['iso3', 'year', 'old_dep', 'ngdp_usd'])
    fp = fp.dropna(subset=['old_dep', 'ngdp_usd'])
    fp['w_oadr'] = fp['old_dep'] * fp['ngdp_usd']
    global_oadr = fp.groupby('year').agg(
        w_sum=('w_oadr', 'sum'), gdp_sum=('ngdp_usd', 'sum')
    ).reset_index()
    global_oadr['global_oadr'] = global_oadr['w_sum'] / global_oadr['gdp_sum']
    df = df.merge(global_oadr[['year', 'global_oadr']], on='year', how='left')
    print(f"  global_oadr: {df['global_oadr'].notna().sum():,} non-null")
    return df


def _print_verification(df):
    """Step [13]: print event counts and every downgrade with rating labels."""
    print("\n[13] Verification ...")
    print(f"  Countries: {sorted(df['iso3'].unique())}")
    print(f"  Total downgrade events: {df['downgrade_any'].sum()}")
    print(f"  Lost safe events: {df['lost_safe'].sum()}")
    print(f"  Lost AAA events: {df['lost_aaa'].sum()}")

    def to_label(x):
        # Numeric rating -> label via REVERSE_SCALE; '?' if missing or unknown.
        return REVERSE_SCALE.get(int(x), '?') if pd.notna(x) else '?'

    downgrades = df.loc[df['downgrade_any'] == 1,
                        ['iso3', 'year', 'rating_lag', 'rating_numeric',
                         'downgrade_notch']].copy()
    downgrades['from_rating'] = downgrades['rating_lag'].map(to_label)
    downgrades['to_rating'] = downgrades['rating_numeric'].map(to_label)
    print(f"\n  All downgrade events ({len(downgrades)}):")
    for _, row in downgrades.iterrows():
        print(f"    {row['iso3']} {int(row['year'])}: "
              f"{row['from_rating']} → {row['to_rating']} "
              f"(-{int(row['downgrade_notch'])} notch)")


def main():
    """Build, save and return the safe-asset cliff panel.

    Pipeline: build the ratings panel -> load safe-asset and fiscal panels
    restricted to rated countries -> merge -> add event, fiscal-stress,
    forward-demographic, spline, lag and global-OADR variables -> keep rated
    country-years within [1990, 2024] -> write ``cliff_panel.csv`` to
    PROCESSED_DIR -> print a verification summary.

    Lags and leads are computed on the full merged panel before the final
    restriction, so they can draw on years that are later dropped.

    Returns:
        pandas.DataFrame: the panel that was written to disk.
    """
    first_year, last_year = 1990, 2024

    print("=" * 70)
    print("PHASE 1: Data Assembly — Safe Asset Cliff Panel")
    print("=" * 70)

    # ── [1] Build ratings panel ──
    print("\n[1] Building ratings panel ...")
    ratings = build_ratings_panel(list(range(first_year, last_year + 1)))
    rated_countries = sorted(ratings['iso3'].unique())
    print(f"  Rated countries: {len(rated_countries)}")
    print(f"  Ratings obs: {len(ratings):,}")

    # ── [2] Load safe_asset_panel (rated countries only) ──
    print("\n[2] Loading safe_asset_panel.csv ...")
    safe_df = pd.read_csv(SAFE_DIR / "data" / "processed" / "safe_asset_panel.csv")
    safe_df = safe_df[safe_df['iso3'].isin(rated_countries)].copy()
    print(f"  Safe panel (rated only): {safe_df['iso3'].nunique()} countries, "
          f"{len(safe_df):,} obs")

    # ── [3] Load fiscal_panel (rated countries only) ──
    print("\n[3] Loading fiscal_panel.csv ...")
    fisc = pd.read_csv(FISCAL_DIR / "data" / "processed" / "fiscal_panel.csv")
    fisc = fisc[fisc['iso3'].isin(rated_countries)].copy()
    print(f"  Fiscal panel (rated only): {fisc['iso3'].nunique()} countries, "
          f"{len(fisc):,} obs")

    # ── [4]-[10] Merge and construct variables ──
    df = _merge_panels(safe_df, fisc, ratings)
    df = _add_event_variables(df)
    df = _add_fiscal_stress(df)
    df = _add_forward_demographics(df)
    df = _add_oadr_spline(df)
    df = _add_fiscal_lags(df)
    df = _add_global_oadr(df)

    # ── [11] Restrict to rated country-years in the sample window ──
    print("\n[11] Restricting to rated countries with rating data ...")
    df = df[df['rating_numeric'].notna()].copy()
    df = df[(df['year'] >= first_year) & (df['year'] <= last_year)].copy()
    print(f"  Final panel: {df['iso3'].nunique()} countries, {len(df):,} obs, "
          f"{df['year'].min()}-{df['year'].max()}")

    # ── [12] Save ──
    print("\n[12] Saving cliff_panel.csv ...")
    out_path = PROCESSED_DIR / "cliff_panel.csv"
    df.to_csv(out_path, index=False)
    print(f"  Saved: {out_path}")
    print(f"  Shape: {df.shape[0]:,} obs x {df.shape[1]} columns")

    # ── [13] Verification ──
    _print_verification(df)

    print("\n" + "=" * 70)
    print("Phase 1 complete.")
    print("=" * 70)

    return df


if __name__ == "__main__":
    df = main()
